/*
 * black_box_reduction.h:
 * Logistic regression using our black-box reduction
 * to turnstile l2 point query. JL matrix is used to
 * compress weights. Then an estimate of the weight
 * vector is fed to a Countsketch matrix. Recovery is
 * again done through the standard CCF02 recovery procedure.
 *
 * Implementation is based on logistic_sketch.h/jl_recovery_sketch.h
 * and logistic_sketch.cpp/jl_recovery_sketch.cpp.
 */

#ifndef BLACK_BOX_REDUCTION_H_
#define BLACK_BOX_REDUCTION_H_

#include <cstdint>
#include <utility>
#include <vector>
#include "binary_estimator.h"
#include "hash.h"
#include "countsketch.h"

namespace wmsketch {

/**
 * Binary estimator implementing logistic regression through a black-box
 * reduction to turnstile l2 point query: weights are compressed with a JL
 * matrix and an estimate of the weight vector is fed to a CountSketch.
 *
 * The class holds a raw pointer table (weights_) and declares its own
 * destructor, so it is explicitly non-copyable.
 */
class BlackBoxReduction : public BinaryEstimator {

public:
	// Upper bound on the log2_width constructor argument.
	// NOTE(review): enforcement happens (if at all) in the .cpp — confirm.
	static const uint32_t MAX_LOG2_WIDTH = 31;

private:

	// JL matrix and sketched weights data structures
	// NOTE(review): weights_ is presumably a depth_ x width table allocated
	// in the constructor — confirm against the .cpp.
	float** weights_;
	float bias_;
	const float lr_init_;   // initial learning rate (constructor argument lr_init)
	const float l2_reg_;    // l2 regularization strength (constructor argument l2_reg)
	float scale_;
	uint64_t t_;            // presumably the update/step counter — confirm
	const uint32_t depth_;
	uint32_t width_mask_;   // presumably (width - 1) for a power-of-two width — confirm
	hash::TabulationHash hash_fn_;
	std::vector<uint32_t> hash_buf_;
	std::vector<float> weight_sums_;

	// Countsketch data structure
	// Note that the depth, width, etc. are
	// the same for the JL matrix and Countsketch
	// The scale is also the same.
	CountSketch point_query_l2_;

public:
	/**
	 * @param log2_width base-2 logarithm of the sketch width (see MAX_LOG2_WIDTH)
	 * @param depth      sketch depth, shared by the JL matrix and the CountSketch
	 * @param seed       seed value (presumably for the hash functions — confirm)
	 * @param lr_init    initial learning rate
	 * @param l2_reg     l2 regularization strength
	 */
	BlackBoxReduction(
			uint32_t log2_width,
			uint32_t depth,
			int32_t seed,
			float lr_init = 0.1,
			float l2_reg = 1e-3);
	~BlackBoxReduction() override;

	// Non-copyable: a member-wise copy would duplicate the raw weights_
	// pointer, leaving two objects referring to the same table.
	// NOTE(review): assumes the destructor releases weights_ — confirm.
	// (Copy assignment is already implicitly deleted by the const members;
	// it is spelled out here to make the intent explicit.)
	BlackBoxReduction(const BlackBoxReduction&) = delete;
	BlackBoxReduction& operator=(const BlackBoxReduction&) = delete;

	// Returns the estimate associated with a single feature key.
	float get(uint32_t key) override;

	// Predict the binary label for a single key / a sparse (key, value) vector.
	bool predict(uint32_t key);
	bool predict(const std::vector<std::pair<uint32_t, float>>& x);

	// Update the model with one labelled example. The three-argument overload
	// additionally takes a caller-provided new_weights vector.
	// NOTE(review): the meaning of the bool return value and what is written
	// to new_weights are defined in the .cpp — confirm.
	bool update(uint32_t key, bool label) override;
	bool update(const std::vector<std::pair<uint32_t, float>>& x, bool label) override;
	bool update(std::vector<float>& new_weights, const std::vector<std::pair<uint32_t, float>>& x, bool label) override;

	float bias() override;
	float scale();

private:
	// Internal helpers operating on a sparse (key, value) feature vector.
	// NOTE(review): get_weights presumably fills hash_buf_ / weight_sums_ for
	// the keys in x, and dot computes the model's inner product with x — confirm.
	void get_weights(const std::vector<std::pair<uint32_t, float>>& x);
	float dot(const std::vector<std::pair<uint32_t, float>>& x);
};

}

#endif

